{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benign Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from advsecurenet.models.model_factory import ModelFactory\n",
    "from advsecurenet.datasets.dataset_factory import DatasetFactory\n",
    "from advsecurenet.dataloader.data_loader_factory import DataLoaderFactory\n",
    "from advsecurenet.shared.types.configs.preprocess_config import (\n",
    "    PreprocessConfig,\n",
    "    PreprocessStep,\n",
    ")\n",
    "from advsecurenet.shared.types.configs.device_config import DeviceConfig\n",
    "from advsecurenet.shared.types.configs import TrainConfig\n",
    "from advsecurenet.trainer.trainer import Trainer\n",
    "from advsecurenet.shared.types.configs.model_config import CreateModelConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the network to train: a \"resnet18\" requested with pretrained\n",
    "# weights and 10 output classes (CIFAR-10, loaded below, has 10 labels).\n",
    "model = ModelFactory.create_model(\n",
    "    model_name=\"resnet18\",\n",
    "    num_classes=10,\n",
    "    pretrained=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's describe the preprocessing pipeline as a list of named steps.\n",
    "# NOTE(review): the Normalize mean/std look like the usual ImageNet channel\n",
    "# statistics, which would match the pretrained weights requested for the\n",
    "# model - confirm.\n",
    "preprocess_steps = [\n",
    "    PreprocessStep(name=\"Resize\", params={\"size\": 32}),\n",
    "    PreprocessStep(name=\"CenterCrop\", params={\"size\": 32}),\n",
    "    PreprocessStep(name=\"ToTensor\"),\n",
    "    PreprocessStep(\n",
    "        name=\"ToDtype\",\n",
    "        params={\"dtype\": \"torch.float32\", \"scale\": True},\n",
    "    ),\n",
    "    PreprocessStep(\n",
    "        name=\"Normalize\",\n",
    "        params={\n",
    "            \"mean\": [0.485, 0.456, 0.406],\n",
    "            \"std\": [0.229, 0.224, 0.225],\n",
    "        },\n",
    "    ),\n",
    "]\n",
    "preprocess_config = PreprocessConfig(steps=preprocess_steps)\n",
    "\n",
    "# Create the CIFAR-10 dataset object with that preprocessing attached.\n",
    "# return_loaded=False presumably returns the dataset wrapper rather than the\n",
    "# loaded data, so each split is loaded explicitly right after - confirm.\n",
    "dataset = DatasetFactory.create_dataset(\n",
    "    dataset_type=\"cifar10\",\n",
    "    preprocess_config=preprocess_config,\n",
    "    return_loaded=False,\n",
    ")\n",
    "\n",
    "# Load the training split and the held-out test split.\n",
    "train_data = dataset.load_dataset(train=True)\n",
    "test_data = dataset.load_dataset(train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the dataloader: serve the training split in batches of 32 samples\n",
    "dataloader = DataLoaderFactory.create_dataloader(\n",
    "    dataset=train_data,\n",
    "    batch_size=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the training config.\n",
    "# Choose the processor at run time instead of hard-coding one backend, so the\n",
    "# notebook also runs on machines without Apple's Metal (\"mps\") support.\n",
    "if torch.backends.mps.is_available():\n",
    "    processor = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    processor = \"cuda\"\n",
    "else:\n",
    "    processor = \"cpu\"\n",
    "\n",
    "config = TrainConfig(\n",
    "    model=model,\n",
    "    train_loader=dataloader,\n",
    "    epochs=2,\n",
    "    processor=processor,  # Replace with a fixed device string to force a specific processor\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a Trainer from the training configuration defined above\n",
    "trainer = Trainer(config)\n",
    "\n",
    "# Run training for the configured number of epochs\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using an External Model\n",
    "\n",
    "It's also possible to use a custom external model in the `advsecurenet` library. The following example shows how to load an external model and train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Describe the external model to load.\n",
    "# NOTE(review): presumably this resolves to the `Net` class defined in\n",
    "# ./external_model.py, with the relative path depending on the kernel's\n",
    "# working directory - confirm.\n",
    "model_config = CreateModelConfig(\n",
    "    is_external=True,\n",
    "    model_arch_path=\"./external_model.py\",\n",
    "    model_name=\"Net\",\n",
    "    num_classes=10,\n",
    "    pretrained=False,\n",
    ")\n",
    "\n",
    "# Hand the config object to the same factory used for the built-in model\n",
    "external_model = ModelFactory.create_model(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: print the loaded model. print() is used on purpose so that\n",
    "# a missing model (None) is still shown explicitly instead of producing no output.\n",
    "print(external_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update the training config so that it trains the external model.\n",
    "# Choose the processor at run time instead of hard-coding one backend, so this\n",
    "# cell also runs on machines without Apple's Metal (\"mps\") support.\n",
    "if torch.backends.mps.is_available():\n",
    "    processor = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    processor = \"cuda\"\n",
    "else:\n",
    "    processor = \"cpu\"\n",
    "\n",
    "config = TrainConfig(\n",
    "    model=external_model,\n",
    "    train_loader=dataloader,\n",
    "    epochs=1,\n",
    "    processor=processor,  # Replace with a fixed device string to force a specific processor\n",
    ")\n",
    "\n",
    "# Create the trainer\n",
    "trainer = Trainer(config)\n",
    "\n",
    "# Train the model\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adv2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
